#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "timer.h"

// Number of grid points along x, y and z; these size the field arrays and
// bound the grid loops in main().
#define NX 256
#define NY 256
#define NZ 256
// Number of time steps taken by the evolution loop in main().
#define NSTEPS 2000

//
// evolve a 3D scalar wave equation
// the two fields we'll evolve are:
//    f - the field
//    g - the time derivative of the field
//

int main(int argc, char *argv[]) {

   int i,j,k,n;

   int nx = NX;
   int ny = NY;
   int nz = NZ;
   int nsteps = NSTEPS;

   float x[NX][NY][NZ],y[NX][NY][NZ],z[NX][NY][NZ];
   float f[NX][NY][NZ],g[NX][NY][NZ];
   float fp[NX][NY][NZ],gp[NX][NY][NZ];

   float dx = 2.0f/(nx-1);
   float dy = 2.0f/(ny-1);
   float dz = 2.0f/(nz-1);
   float dt = 0.00000005f; // in order for the system to be numerically dt < dx!!!

   // initialize the grid to run from -1 to 1 in each direction
   for (i=0; i<NX; i++) {
      for (j=0; j<NY; j++) {
         for (k=0; k<NZ; k++) {
            x[i][j][k] = -1.0f + (i)*dx;
            y[i][j][k] = -1.0f + (j)*dy;
            z[i][j][k] = -1.0f + (k)*dz;
         }
      }
   } 

   // initialize the field to be a gaussian
   for (i=0; i<NX; i++) {
      for (j=0; j<NY; j++) {
         for (k=0; k<NZ; k++) {
            f[i][j][k] = 0.2f * exp( - ( x[i][j][k]*x[i][j][k] + 
                                         y[i][j][k]*y[i][j][k] + 
                                         z[i][j][k]*z[i][j][k] )/0.05f);
            g[i][j][k] = 0.0f;
         }
      }
   } 

   // output the initial data when there are an even number of points, 
   // pick a line closest to a coordinate axis
   FILE *fPtr = fopen("wave3d.xline", "w");
   for (i=0; i<NX; i++) {
      fprintf(fPtr,"%5.3f %10.6e\n",x[i][NY/2][NZ/2],f[i][NY/2][NZ/2]);
   }
   fprintf(fPtr,"\n");

   float step = 0.0f;
   int printevery = 20;
   printf("step = %9.6f \n",step);

   StartTimer();

   for (n=0; n<nsteps; n++) {

      step = step + dt;
    
      if (((n+1)%printevery)==0)
         printf("step = %9.6f \n",step);
    
      #pragma omp parallel private(i,j,k)
      {

         // predictor
         #pragma omp for schedule(static) 
         for (i=0; i<NX; i++) {
            for (j=0; j<NY; j++) {
               for (k=0; k<NZ; k++) {
                  fp[i][j][k] = f[i][j][k] + dt * g[i][j][k];
               }
            }
         } 
      
         // static boundaries
         for (j=0; j<NY; j++) {
            for (k=0; k<NZ; k++) {
               gp[0][j][k] =  g[0][j][k];
               gp[NX-1][j][k] = g[NX-1][j][k];
            }
         } 
      
         for (i=0; i<NX; i++) {
            for (k=0; k<NZ; k++) {
               gp[i][0][k] =  g[i][0][k];
               gp[i][NY-1][k] = g[i][NY-1][k];
            }
         } 
      
         for (i=0; i<NX; i++) {
            for (j=0; j<NY; j++) {
               gp[i][j][0] =  g[i][j][0];
               gp[i][j][NZ-1] = g[i][j][NZ-1];
            }
         } 
      
         // use the predictor to update gp
         #pragma omp for schedule(static) 
         for (i=1; i<NX-1; i++) {
            for (j=1; j<NY-1; j++) {
               for (k=1; k<NZ-1; k++) {
                  gp[i][j][k] = g[i][j][k] + dt*( 
                     (fp[i+1][j][k]-2.0f*fp[i][j][k]+fp[i-1][j][k])/dx/dx +
                     (fp[i][j+1][k]-2.0f*fp[i][j][k]+fp[i][j-1][k])/dy/dy +
                     (fp[i][j][k+1]-2.0f*fp[i][j][k]+fp[i][j][k-1])/dz/dz);
               }
            }
         } 
      
         // use the average g's to update f
         #pragma omp for schedule(static) 
         for (i=0; i<NX; i++) {
            for (j=0; j<NY; j++) {
               for (k=0; k<NZ; k++) {
                  fp[i][j][k] = f[i][j][k] + dt * (0.5f * (g[i][j][k]+gp[i][j][k]));
               }
            }
         } 
      
         // now update all the variables
         #pragma omp for schedule(static) 
         for (i=0; i<NX; i++) {
            for (j=0; j<NY; j++) {
               for (int k=0; k<NZ; k++) {
                  f[i][j][k] = fp[i][j][k];
                  g[i][j][k] = gp[i][j][k];
               }
            }
         } 
      
      } // pragma omp
    
      if (((n+1)%printevery)==0) {
         for (i=0; i<NX; i++) {
            fprintf(fPtr,"%5.3f %10.6e\n",x[i][NY/2][NZ/2],f[i][NY/2][NZ/2]);
         }
         fprintf(fPtr,"\n");
      }
    
    
   } // for nsteps

   float computeTime = GetTimer();
   printf("Compute time: %f seconds\n", computeTime / 1000.0f);

   exit(0);
}
